{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8a34796c",
   "metadata": {},
   "source": [
    "## 数据检索\n",
    "用户输入只是问题，答案需要从已知的数据库或者文档中获取。因此，我们需要使用检索器获取上下文，并通过“question”键下的用户输入。\n",
    "\n",
    "### RAG\n",
    "检索增强生成（Retrieval-augmented Generation，RAG），是当下最热门的大模型前沿技术之一。如果将“微调（finetune）”理解成大模型内化吸收知识的过程，那么RAG就相当于给大模型装上了“知识外挂”，基础大模型不用再训练即可随时调用特定领域知识。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b1dce1f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\a4154\\AppData\\Roaming\\Python\\Python310\\site-packages\\transformers\\utils\\hub.py:124: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n",
      "  warnings.warn(\n",
      "d:\\ProgramData\\anaconda3\\lib\\site-packages\\torch\\_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n"
     ]
    }
   ],
   "source": [
    "from langchain_community.vectorstores import FAISS\n",
    "from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n",
    "\n",
    "embeddings_path = \"D:\\\\ai\\\\download\\\\bge-large-zh-v1.5\"\n",
    "embeddings = HuggingFaceEmbeddings(model_name=embeddings_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "93c58daf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<langchain_community.vectorstores.faiss.FAISS at 0x1e30d24b910>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vectorstore = FAISS.from_texts(\n",
    "    [\"小明在华为工作\",\"熊喜欢吃蜂蜜\"],\n",
    "    embedding=embeddings\n",
    ")\n",
    "vectorstore\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "38c41de1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#使用向量数据库生成检索器\n",
    "retriever = vectorstore.as_retriever()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "be0b596f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Document(page_content='熊喜欢吃蜂蜜'), Document(page_content='小明在华为工作')]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "retriever.invoke(\"熊喜欢吃什么？\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a8cbf7f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI, OpenAI\n",
    "openai_api_key = \"EMPTY\"\n",
    "openai_api_base = \"http://127.0.0.1:1234/v1\"\n",
    "model = ChatOpenAI(\n",
    "    openai_api_key=openai_api_key,\n",
    "    openai_api_base=openai_api_base,\n",
    "    temperature=0.3,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c7abdb3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "\n",
    "template =\"\"\"\n",
    "只根据以下文档回答问题：\n",
    "{context}\n",
    "\n",
    "问题：{question}\n",
    "\"\"\"\n",
    "\n",
    "prompt = ChatPromptTemplate.from_template(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3b445a01",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.runnables import RunnableParallel,RunnablePassthrough\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "\n",
    "outputParser = StrOutputParser()\n",
    "\n",
    "setup_and_retrieval = RunnableParallel(\n",
    "    {\n",
    "        \"context\":retriever,\n",
    "        \"question\":RunnablePassthrough()\n",
    "    }\n",
    ")\n",
    "\n",
    "chain = setup_and_retrieval | prompt | model | outputParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b95952b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'小明在华为工作。'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke(\"小明在哪里工作？\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cb888747",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'小明在华为工作。'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#完整链式\n",
    "chain = (\n",
    "    {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
    "    | prompt\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")\n",
    "\n",
    "chain.invoke(\"小明在哪里工作？\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "deb624a4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'主人，小明在华为工作。'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from operator import itemgetter\n",
    "\n",
    "template =\"\"\"\n",
    "只根据以下文档回答问题：\n",
    "{context}\n",
    "\n",
    "问题：{question}\n",
    "回答问题请加上称呼\"{name}\"。\n",
    "\"\"\"\n",
    "prompt = ChatPromptTemplate.from_template(template)\n",
    "\n",
    "chain = (\n",
    "    {\n",
    "        \"context\": itemgetter(\"question\") | retriever,\n",
    "        \"question\": itemgetter(\"question\"),\n",
    "        \"name\": itemgetter(\"name\"),\n",
    "    }\n",
    "    | prompt\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")\n",
    "\n",
    "chain.invoke({\"question\":\"小明在哪里工作？\",\"name\":\"主人\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a48669b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.document_loaders import DirectoryLoader\n",
    "\n",
    "loader = DirectoryLoader('./txt')\n",
    "\n",
    "docs = loader.load()\n",
    "docs"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
